{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyM+98apepmO9YHdeIxM5wlu",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sugarforever/LangChain-Tutorials/blob/main/expression-language/LangChain_Expression_03_Router.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 03 Router\n",
        "\n",
        "`LCEL` (`LangChain Expression Language`) provides a routing mechanism that lets an application forward each request to a designated chain according to business needs.\n",
        "\n",
        "Core class: `RouterRunnable`"
      ],
      "metadata": {
        "id": "OHpaChk4b5UN"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Set up the environment\n",
        "\n",
        "Install the required Python packages."
      ],
      "metadata": {
        "id": "IAzbVA50cVt1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q -U langchain openai"
      ],
      "metadata": {
        "id": "MZWl8x3dmjsV"
      },
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from operator import itemgetter\n",
        "from typing import Dict, Optional\n",
        "\n",
        "from langchain.chat_models import ChatOpenAI\n",
        "from langchain.load.serializable import Serializable\n",
        "from langchain.prompts import ChatPromptTemplate\n",
        "from langchain.schema import StrOutputParser\n",
        "from langchain.schema.runnable import Input, Runnable, RunnableConfig, RunnableMap, RunnablePassthrough"
      ],
      "metadata": {
        "id": "pYpbezvYwzxU"
      },
      "execution_count": 30,
      "outputs": []
    },
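    {
      "cell_type": "markdown",
      "source": [
        "## A handy debugging class: `StdOutputRunnable`\n",
        "\n",
        "`StdOutputRunnable` behaves almost exactly like `RunnablePassthrough`; the only difference is that it prints its input. That makes it well suited to debugging `LCEL` pipelines: add it at any stage of a pipeline to inspect the value arriving at that stage."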
      ],
      "metadata": {
        "id": "yf8R7S7q3Dll"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class StdOutputRunnable(Serializable, Runnable[Input, Input]):\n",
        "    \"\"\"Pass-through runnable that prints its input, for debugging a pipeline.\"\"\"\n",
        "\n",
        "    @property\n",
        "    def lc_serializable(self) -> bool:\n",
        "        return True\n",
        "\n",
        "    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:\n",
        "        # Print the value arriving at this stage, then return it unchanged.\n",
        "        print(input)\n",
        "        return self._call_with_config(lambda x: x, input, config)\n"
      ],
      "metadata": {
        "id": "xkgc0euJ0Dnz"
      },
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## `RouterRunnable` example"
      ],
      "metadata": {
        "id": "YjY2hgpq37xR"
      }
    },
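    {
      "cell_type": "markdown",
      "source": [
        "First, define a helper model class that specifies the allowed category values for the classification task.\n",
        "\n",
        "`create_tagging_chain_pydantic` creates a tagging (classification) chain based on a `Pydantic` schema; see the [API documentation](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.tagging.create_tagging_chain_pydantic.html)."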
      ],
      "metadata": {
        "id": "fcizgwUh4EBZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "# Replace the placeholder with your own valid OpenAI API key.\n",
        "os.environ['OPENAI_API_KEY'] = \"your-valid-openai-api-key\""
      ],
      "metadata": {
        "id": "oYHp-mBk4dO5"
      },
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.chains import create_tagging_chain_pydantic\n",
        "from pydantic import BaseModel, Field\n",
        "\n",
        "class ChainToUse(BaseModel):\n",
        "    \"\"\"Used to determine which chain to serve the user.\"\"\"\n",
        "\n",
        "    name: str = Field(description=\"Should be one of `color` or `fruit`\")\n",
        "\n",
        "tagger = create_tagging_chain_pydantic(ChainToUse, ChatOpenAI(temperature=0))"
      ],
      "metadata": {
        "id": "2RkcxjTV3hW4"
      },
      "execution_count": 40,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model = ChatOpenAI()"
      ],
      "metadata": {
        "id": "kQ-MyoMT4n6o"
      },
      "execution_count": 41,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "color_chain = ChatPromptTemplate.from_template(\"You are a color expert. Answer the question about color: {question}\") | model\n",
        "fruit_chain = ChatPromptTemplate.from_template(\"You are a fruit expert. Answer the question about fruit: {question}\") | model"
      ],
      "metadata": {
        "id": "3C0IOKs44mjz"
      },
      "execution_count": 42,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.schema.runnable import RouterRunnable\n",
        "router = RouterRunnable({\"color\": color_chain, \"fruit\": fruit_chain})"
      ],
      "metadata": {
        "id": "EPj6CgAoxP4m"
      },
      "execution_count": 43,
      "outputs": []
    },
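    {
      "cell_type": "code",
      "source": [
        "# RouterRunnable expects a dict with two entries:\n",
        "#   \"key\":   the name of the chain to run (here, the tagger's classification)\n",
        "#   \"input\": the payload passed to that chain\n",
        "chain = {\n",
        "    \"key\": RunnablePassthrough() | tagger | StdOutputRunnable() | (lambda x: x['text'].name),\n",
        "    \"input\": {\"question\": RunnablePassthrough()}\n",
        "} | router"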
      ],
      "metadata": {
        "id": "AsobBwvR5D2M"
      },
      "execution_count": 46,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "chain.invoke(\"What is the HEX code of color YELLOW?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2q3tEt8n5Fsd",
        "outputId": "417c29f4-c48a-4827-f367-d587944dc449"
      },
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'input': 'What is the HEX code of color YELLOW?', 'text': ChainToUse(name='color')}\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "AIMessage(content='The HEX code for the color yellow is #FFFF00.', additional_kwargs={}, example=False)"
            ]
          },
          "metadata": {},
          "execution_count": 47
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "chain.invoke(\"What country grow most apples in 2000?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3TMgF_t_5gST",
        "outputId": "0cf65277-49f8-4838-a6d0-814ca69d41c2"
      },
      "execution_count": 48,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'input': 'What country grow most apples in 2000?', 'text': ChainToUse(name='fruit')}\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "AIMessage(content='In the year 2000, China was the country that grew the most apples.', additional_kwargs={}, example=False)"
            ]
          },
          "metadata": {},
          "execution_count": 48
        }
      ]
    }
  ]
}